import cv2
import torch

from ..env import Paint
from ..utils.util import *

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class fastenv:
    def __init__(
        self,
        max_episode_length=10,
        env_batch=64,
        writer=None,
        images=None,
        device="cpu",
        Decoder=None,
    ):
        self.max_episode_length = max_episode_length
        self.env_batch = env_batch
        self.device = device
        self.Decoder = Decoder
        self.env = Paint(
            self.env_batch,
            self.max_episode_length,
            device=self.device,
            Decoder=self.Decoder,
        )
        self.env.load_data(images)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.writer = writer
        self.test = False
        self.log = 0

    def save_image(self, log, step):
        for i in range(self.env_batch):
            if self.env.imgid[i] <= 10:
                canvas = cv2.cvtColor(
                    (to_numpy(self.env.canvas[i].permute(1, 2, 0))), cv2.COLOR_BGR2RGB
                )
                self.writer.add_image(
                    "{}/canvas_{}.png".format(str(self.env.imgid[i]), str(step)),
                    canvas,
                    log,
                )
        if step == self.max_episode_length:
            for i in range(self.env_batch):
                if self.env.imgid[i] < 50:
                    gt = cv2.cvtColor(
                        (to_numpy(self.env.gt[i].permute(1, 2, 0))), cv2.COLOR_BGR2RGB
                    )
                    canvas = cv2.cvtColor(
                        (to_numpy(self.env.canvas[i].permute(1, 2, 0))),
                        cv2.COLOR_BGR2RGB,
                    )
                    self.writer.add_image(
                        str(self.env.imgid[i]) + "/_target.png", gt, log
                    )
                    self.writer.add_image(
                        str(self.env.imgid[i]) + "/_canvas.png", canvas, log
                    )

    def step(self, action):
        with torch.no_grad():
            ob, r, d, _ = self.env.step(torch.tensor(action).to(self.device))
        if d[0]:
            if not self.test:
                self.dist = self.get_dist()
                for i in range(self.env_batch):
                    if self.writer:
                        self.writer.add_scalar("train/dist", self.dist[i], self.log)
                    self.log += 1
        return ob, r, d, _

    def get_dist(self):
        return to_numpy(
            (((self.env.gt.float() - self.env.canvas.float()) / 255) ** 2)
            .mean(1)
            .mean(1)
            .mean(1)
        )

    def reset(self, test=False, episode=0):
        self.test = test
        ob = self.env.reset(self.test, episode * self.env_batch)
        return ob
